#!/usr/bin/env python
import argparse
import kalibr_common as kc
from mpl_toolkits.mplot3d import art3d, Axes3D, proj3d
import numpy as np
import pylab as pl
import sm


def parse_arguments():
    """Parse the command line options of the visualization script.

    Returns:
        The ``argparse.Namespace`` with ``chain_yaml`` (list of yaml paths),
        ``azim_view``, ``elev_view`` and ``scaling_factor`` (floats).
    """
    arg_parser = argparse.ArgumentParser(
        description=
        'Visualizes a camchain-cam.yaml or camchain-imucam.yaml file.')

    # The only mandatory option: one or more camera chain files.
    arg_parser.add_argument(
        '--cam',
        dest='chain_yaml',
        nargs='+',
        required=True,
        help=
        'Camera configuration as yaml file (can take multiple configurations).'
    )

    # Optional float-valued options: (flag, destination, help, default).
    float_options = (
        ('--azim', 'azim_view', 'Azimuth angle of the view.', 0.0),
        ('--elev', 'elev_view', 'Elevation angle of the view.', 90.0),
        ('--scaling', 'scaling_factor',
         'Scaling factor of the visualization (pixel size in meters).', 1e-4),
    )
    for flag, destination, help_text, default_value in float_options:
        arg_parser.add_argument(
            flag,
            dest=destination,
            help=help_text,
            type=float,
            default=default_value,
            required=False)

    return arg_parser.parse_args()


def set_axis_equal_3d(ax):
    """Give the x, y and z axes of a 3D plot the same scale.

    The current limits of the three axes are read, and each axis is then
    re-limited to a range of identical length (the largest of the three),
    centered on the midpoint of its previous range.

    Args:
        ax: Axes object providing ``get_{x,y,z}lim``, ``set_{x,y,z}lim`` and
            ``set_aspect``.
    """
    extents = np.array(
        [getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
    centers = np.mean(extents, axis=1)
    # Half of the longest axis range; every axis gets this half-width.
    half_range = max(abs(extents[:, 1] - extents[:, 0])) / 2
    for center, dim in zip(centers, 'xyz'):
        getattr(ax, 'set_{}lim'.format(dim))(center - half_range,
                                             center + half_range)
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Some Matplotlib releases refuse 'equal' on 3D axes.  The limits
        # above already span a cube, so a cubic box yields the same result
        # where the axes object supports setting the box aspect.
        if hasattr(ax, 'set_box_aspect'):
            ax.set_box_aspect((1, 1, 1))


def set_orthogonal_projection(zfront, zback):
    """Build a 4x4 projection matrix from the front and back clip depths.

    The x and y rows are those of the identity matrix.  The third row maps
    depth with the terms derived from ``zfront`` and ``zback``; the last row
    is ``[0, 0, -1e-5, zback]``, so the homogeneous coordinate
    ``w = -1e-5 * z + zback`` changes only slightly with z.

    Args:
        zfront: Depth of the front clipping plane.
        zback: Depth of the back clipping plane.

    Returns:
        The projection matrix as a 4x4 numpy array.
    """
    depth_span = zfront - zback
    depth_scale = (zfront + zback) / depth_span
    depth_offset = -2 * (zfront * zback) / depth_span
    rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, depth_scale, depth_offset],
        [0, 0, -1e-5, zback],
    ]
    return np.array(rows)


def transform_point_cloud(transform, points):
    """Apply ``transform`` to every row of an N x 3 point array.

    Each point is handed to ``transform`` as a 3 x 1 column through the
    ``*`` operator and the product is written back as a row.

    Args:
        transform: Object whose ``*`` operator accepts a 3 x 1 array and
            returns something reshapeable to 1 x 3.
        points: Array of shape (N, 3); it is not modified.

    Returns:
        A copy of ``points`` holding the transformed coordinates.
    """
    assert points.shape[1] == 3
    transformed = points.copy()
    for row, point in enumerate(points):
        column = point.reshape(3, 1)
        transformed[row, :] = (transform * column).reshape(1, 3)
    return transformed


class CalibrationVisualization:
    """Collects camera pyramids and coordinate frames in a single 3D figure.

    Each camera is drawn as a pyramid: the apex is the camera origin and the
    base is the image rectangle, both scaled by ``scaling_factor``.  Cameras
    drawn between two ``increment_color`` calls share one fill color.
    """

    def __init__(self, scaling_factor):
        """Create the figure, its 3D axes and the color table.

        Args:
            scaling_factor: Multiplier applied to pixel quantities (focal
                length, principal point, resolution) before plotting.
        """
        self.scaling_factor = scaling_factor
        self.color_counter = 0
        # Twelve samples of the Pastel1 colormap, indexed modulo 12 later.
        self.color_list = pl.cm.Pastel1(np.linspace(0, 1, 12))
        self.fig = pl.figure()
        self.ax = self.fig.add_subplot(111, projection='3d')

    def add_camera_plot(self, transform, intrinsics, resolution):
        """Draw one camera: its pyramid outline plus its coordinate frame.

        Args:
            transform: Camera pose, forwarded unchanged to the draw helpers.
            intrinsics: Sequence whose entries 0-1 are the focal lengths and
                entries 2-3 the principal point.
            resolution: Image extent along x and y in pixels.
        """
        focal_lengths, principal_point = intrinsics[0:2], intrinsics[2:4]
        self.draw_camera_outline(transform, focal_lengths, principal_point,
                                 resolution)
        self.add_coordinate_frame_plot(transform)

    def add_coordinate_frame_plot(self, transform):
        """Draw a coordinate frame for ``transform`` on the 3D axes.

        Args:
            transform: Object exposing ``T()``; its result is handed to
                ``sm.plotCoordinateFrame``.
        """
        frame_size = 3e2 * self.scaling_factor
        sm.plotCoordinateFrame(self.ax, transform.T(), size=frame_size)

    def show_plot(self, azim, elev):
        """Set the viewing angles, equalize the axes and show the figure.

        Args:
            azim: Azimuth angle passed to ``view_init``.
            elev: Elevation angle passed to ``view_init``.
        """
        self.ax.view_init(elev=elev, azim=azim)
        # Replaces the module-level attribute, so this affects the whole
        # process.  NOTE(review): whether Matplotlib still consults this
        # attribute after the axes exist depends on its version -- confirm.
        proj3d.persp_transformation = set_orthogonal_projection
        set_axis_equal_3d(self.ax)
        pl.show()

    def increment_color(self):
        """Advance to the next fill color for subsequently drawn cameras."""
        self.color_counter = self.color_counter + 1

    def draw_camera_outline(self, transform, focal_lengths, principal_point,
                            resolution):
        """Draw the camera pyramid (image rectangle plus camera origin).

        Args:
            transform: Pose applied to the pyramid points through
                ``transform_point_cloud``.
            focal_lengths: Focal lengths along x and y in pixels.
            principal_point: Principal point (x, y) in pixels.
            resolution: Image extent along x and y in pixels.
        """
        scale = self.scaling_factor
        pixel_ratio = focal_lengths[0] / focal_lengths[1]
        plane_depth = scale * focal_lengths[0]
        # Image rectangle relative to the principal point; y is additionally
        # stretched by the focal length ratio.
        x_min, x_max = scale * np.array(
            [-principal_point[0], resolution[0] - principal_point[0]])
        y_min, y_max = pixel_ratio * scale * np.array(
            [-principal_point[1], resolution[1] - principal_point[1]])

        # Rows 0-3: image plane corners, row 4: the camera origin.
        rectangle = [(x_min, y_min), (x_max, y_min), (x_max, y_max),
                     (x_min, y_max)]
        raw_points = np.array(
            [[x, y, plane_depth] for x, y in rectangle] + [[0.0, 0.0, 0.0]])
        points = transform_point_cloud(transform, raw_points)

        # Four triangular sides sharing the origin, then the image plane.
        face_indices = [(0, 1, 4), (0, 3, 4), (2, 1, 4), (2, 3, 4),
                        (0, 1, 2, 3)]
        faces = [[points[index] for index in face] for face in face_indices]
        fill_color = self.color_list[self.color_counter % 12]

        # Focal point.
        origin = points[4]
        self.ax.scatter3D(
            origin[0],
            origin[1],
            origin[2],
            c=fill_color,
            s=3e5 * scale)
        # All pyramid points (image plane corners and the origin again).
        self.ax.scatter3D(
            points[:, 0],
            points[:, 1],
            points[:, 2],
            c=fill_color,
            s=1e5 * scale)
        # Planes of camera models.
        pyramid = art3d.Poly3DCollection(
            faces,
            facecolors=fill_color,
            linewidths=1,
            edgecolors='k',
            alpha=0.5,
            antialiased=True)
        self.ax.add_collection3d(pyramid)


def main():
    """Plot every camera of every given camera chain and show the figure.

    Each chain file is read with ``kc.ConfigReader.CameraChainParameters``.
    The cameras of a chain are placed by accumulating the inverted
    camera-to-camera extrinsics, starting from a fixed base rotation.  If
    cam0 has a ``T_cam_imu`` entry, a frame for the IMU is drawn as well.
    All cameras of one chain share a color.
    """
    args = parse_arguments()
    visualizer = CalibrationVisualization(args.scaling_factor)

    for yaml_path in args.chain_yaml:
        chain = kc.ConfigReader.CameraChainParameters(yaml_path)
        camera_count = chain.numCameras()
        print("Plotting configuration with " + str(camera_count) + " cameras.")

        # Base rotation for cam0: maps (x, y, z) to (x, z, -y).
        base_rotation = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0],
                                  [0, 0, 0, 1]])
        T_plot_cam = sm.Transformation(base_rotation)
        T_plot_previous = sm.Transformation()

        for cam_index in range(camera_count):
            cam_key = 'cam{0}'.format(cam_index)
            is_first_camera = cam_key == 'cam0'

            if not is_first_camera:
                # Chain the pose of the previous camera with the inverted
                # extrinsics stored for this camera.
                T_previous_cam = chain.getExtrinsicsLastCamToHere(
                    cam_index).inverse()
                T_plot_cam = T_plot_previous * T_previous_cam
            elif 'T_cam_imu' in chain.data[cam_key]:
                T_plot_imu = T_plot_cam * chain.getExtrinsicsImuToCam(
                    cam_index)
                visualizer.add_coordinate_frame_plot(T_plot_imu)

            intrinsics = np.asarray(chain.data[cam_key]['intrinsics'])
            resolution = np.asarray(chain.data[cam_key]['resolution'])
            if is_first_camera:
                visualizer.add_coordinate_frame_plot(T_plot_cam)
            visualizer.add_camera_plot(T_plot_cam, intrinsics, resolution)
            T_plot_previous = T_plot_cam

        visualizer.increment_color()

    visualizer.show_plot(args.azim_view, args.elev_view)


# Run the visualization only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
